{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from onnx import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Int attribute:\n",
      "\n",
      "name: \"this_is_an_int\"\n",
      "i: 1701\n",
      "type: INT\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Int Attibute\n",
    "arg = helper.make_attribute(\"this_is_an_int\", 1701)\n",
    "print(\"\\nInt attribute:\\n\")\n",
    "print(arg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Float attribute:\n",
      "\n",
      "name: \"this_is_a_float\"\n",
      "f: 3.1400001049\n",
      "type: 1\n",
      "\n"
     ]
    }
   ],
   "source": [
    "#NBVAL_IGNORE_OUTPUT\n",
    "# Float Attribute\n",
    "arg = helper.make_attribute(\"this_is_a_float\", 3.14)\n",
    "print(\"\\nFloat attribute:\\n\")\n",
    "print(arg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "String attribute:\n",
      "\n",
      "name: \"this_is_a_string\"\n",
      "s: \"string_content\"\n",
      "type: STRING\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# String Attribute\n",
    "arg = helper.make_attribute(\"this_is_a_string\", \"string_content\")\n",
    "print(\"\\nString attribute:\\n\")\n",
    "print(arg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Repeated int attribute:\n",
      "\n",
      "name: \"this_is_a_repeated_int\"\n",
      "ints: 1\n",
      "ints: 2\n",
      "ints: 3\n",
      "ints: 4\n",
      "type: INTS\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Repeated Attribute\n",
    "arg = helper.make_attribute(\"this_is_a_repeated_int\", [1, 2, 3, 4])\n",
    "print(\"\\nRepeated int attribute:\\n\")\n",
    "print(arg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "NodeProto:\n",
      "\n",
      "input: \"X\"\n",
      "output: \"Y\"\n",
      "op_type: \"Relu\"\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# node\n",
    "node_proto = helper.make_node(\"Relu\", [\"X\"], [\"Y\"])\n",
    "\n",
    "print(\"\\nNodeProto:\\n\")\n",
    "print(node_proto)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "NodeProto:\n",
      "\n",
      "input: \"X\"\n",
      "input: \"W\"\n",
      "input: \"B\"\n",
      "output: \"Y\"\n",
      "op_type: \"Conv\"\n",
      "attribute {\n",
      "  name: \"kernel\"\n",
      "  i: 3\n",
      "  type: INT\n",
      "}\n",
      "attribute {\n",
      "  name: \"pad\"\n",
      "  i: 1\n",
      "  type: INT\n",
      "}\n",
      "attribute {\n",
      "  name: \"stride\"\n",
      "  i: 1\n",
      "  type: INT\n",
      "}\n",
      "\n",
      "\n",
      "More Readable NodeProto (no args yet):\n",
      "\n",
      "%Y = Conv[kernel = 3, pad = 1, stride = 1](%X, %W, %B)\n"
     ]
    }
   ],
   "source": [
    "# node with args\n",
    "node_proto = helper.make_node(\n",
    "    \"Conv\", [\"X\", \"W\", \"B\"], [\"Y\"],\n",
    "    kernel=3, stride=1, pad=1)\n",
    "\n",
    "# This is just for making the attributes to be printed in order\n",
    "node_proto.attribute.sort(key=lambda attr: attr.name)\n",
    "print(\"\\nNodeProto:\\n\")\n",
    "print(node_proto)\n",
    "\n",
    "print(\"\\nMore Readable NodeProto (no args yet):\\n\")\n",
    "print(helper.printable_node(node_proto))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "graph proto:\n",
      "\n",
      "node {\n",
      "  input: \"X\"\n",
      "  input: \"W1\"\n",
      "  input: \"B1\"\n",
      "  output: \"H1\"\n",
      "  op_type: \"FC\"\n",
      "}\n",
      "node {\n",
      "  input: \"H1\"\n",
      "  output: \"R1\"\n",
      "  op_type: \"Relu\"\n",
      "}\n",
      "node {\n",
      "  input: \"R1\"\n",
      "  input: \"W2\"\n",
      "  input: \"B2\"\n",
      "  output: \"Y\"\n",
      "  op_type: \"FC\"\n",
      "}\n",
      "name: \"MLP\"\n",
      "input {\n",
      "  name: \"X\"\n",
      "  type {\n",
      "    tensor_type {\n",
      "      elem_type: 1\n",
      "      shape {\n",
      "        dim {\n",
      "          dim_value: 1\n",
      "        }\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "}\n",
      "input {\n",
      "  name: \"W1\"\n",
      "  type {\n",
      "    tensor_type {\n",
      "      elem_type: 1\n",
      "      shape {\n",
      "        dim {\n",
      "          dim_value: 1\n",
      "        }\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "}\n",
      "input {\n",
      "  name: \"B1\"\n",
      "  type {\n",
      "    tensor_type {\n",
      "      elem_type: 1\n",
      "      shape {\n",
      "        dim {\n",
      "          dim_value: 1\n",
      "        }\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "}\n",
      "input {\n",
      "  name: \"W2\"\n",
      "  type {\n",
      "    tensor_type {\n",
      "      elem_type: 1\n",
      "      shape {\n",
      "        dim {\n",
      "          dim_value: 1\n",
      "        }\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "}\n",
      "input {\n",
      "  name: \"B2\"\n",
      "  type {\n",
      "    tensor_type {\n",
      "      elem_type: 1\n",
      "      shape {\n",
      "        dim {\n",
      "          dim_value: 1\n",
      "        }\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "}\n",
      "output {\n",
      "  name: \"Y\"\n",
      "  type {\n",
      "    tensor_type {\n",
      "      elem_type: 1\n",
      "      shape {\n",
      "        dim {\n",
      "          dim_value: 1\n",
      "        }\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "}\n",
      "\n",
      "\n",
      "More Readable GraphProto:\n",
      "\n",
      "graph MLP (\n",
      "  %X[FLOAT, 1]\n",
      "  %W1[FLOAT, 1]\n",
      "  %B1[FLOAT, 1]\n",
      "  %W2[FLOAT, 1]\n",
      "  %B2[FLOAT, 1]\n",
      ") {\n",
      "  %H1 = FC(%X, %W1, %B1)\n",
      "  %R1 = Relu(%H1)\n",
      "  %Y = FC(%R1, %W2, %B2)\n",
      "  return %Y\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "# graph\n",
    "graph_proto = helper.make_graph(\n",
    "    [\n",
    "        helper.make_node(\"FC\", [\"X\", \"W1\", \"B1\"], [\"H1\"]),\n",
    "        helper.make_node(\"Relu\", [\"H1\"], [\"R1\"]),\n",
    "        helper.make_node(\"FC\", [\"R1\", \"W2\", \"B2\"], [\"Y\"]),\n",
    "    ],\n",
    "    \"MLP\",\n",
    "    [\n",
    "        helper.make_tensor_value_info('X' , TensorProto.FLOAT, [1]),\n",
    "        helper.make_tensor_value_info('W1', TensorProto.FLOAT, [1]),\n",
    "        helper.make_tensor_value_info('B1', TensorProto.FLOAT, [1]),\n",
    "        helper.make_tensor_value_info('W2', TensorProto.FLOAT, [1]),\n",
    "        helper.make_tensor_value_info('B2', TensorProto.FLOAT, [1]),\n",
    "    ],\n",
    "    [\n",
    "        helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1]),\n",
    "    ]\n",
    ")\n",
    "\n",
    "print(\"\\ngraph proto:\\n\")\n",
    "print(graph_proto)\n",
    "\n",
    "print(\"\\nMore Readable GraphProto:\\n\")\n",
    "print(helper.printable_graph(graph_proto))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "NodeProto that contains a graph:\n",
      "\n",
      "input: \"Input\"\n",
      "input: \"W1\"\n",
      "input: \"B1\"\n",
      "input: \"W2\"\n",
      "input: \"B2\"\n",
      "output: \"Output\"\n",
      "op_type: \"graph\"\n",
      "attribute {\n",
      "  name: \"graph\"\n",
      "  graphs {\n",
      "    node {\n",
      "      input: \"X\"\n",
      "      input: \"W1\"\n",
      "      input: \"B1\"\n",
      "      output: \"H1\"\n",
      "      op_type: \"FC\"\n",
      "    }\n",
      "    node {\n",
      "      input: \"H1\"\n",
      "      output: \"R1\"\n",
      "      op_type: \"Relu\"\n",
      "    }\n",
      "    node {\n",
      "      input: \"R1\"\n",
      "      input: \"W2\"\n",
      "      input: \"B2\"\n",
      "      output: \"Y\"\n",
      "      op_type: \"FC\"\n",
      "    }\n",
      "    name: \"MLP\"\n",
      "    input {\n",
      "      name: \"X\"\n",
      "      type {\n",
      "        tensor_type {\n",
      "          elem_type: 1\n",
      "          shape {\n",
      "            dim {\n",
      "              dim_value: 1\n",
      "            }\n",
      "          }\n",
      "        }\n",
      "      }\n",
      "    }\n",
      "    input {\n",
      "      name: \"W1\"\n",
      "      type {\n",
      "        tensor_type {\n",
      "          elem_type: 1\n",
      "          shape {\n",
      "            dim {\n",
      "              dim_value: 1\n",
      "            }\n",
      "          }\n",
      "        }\n",
      "      }\n",
      "    }\n",
      "    input {\n",
      "      name: \"B1\"\n",
      "      type {\n",
      "        tensor_type {\n",
      "          elem_type: 1\n",
      "          shape {\n",
      "            dim {\n",
      "              dim_value: 1\n",
      "            }\n",
      "          }\n",
      "        }\n",
      "      }\n",
      "    }\n",
      "    input {\n",
      "      name: \"W2\"\n",
      "      type {\n",
      "        tensor_type {\n",
      "          elem_type: 1\n",
      "          shape {\n",
      "            dim {\n",
      "              dim_value: 1\n",
      "            }\n",
      "          }\n",
      "        }\n",
      "      }\n",
      "    }\n",
      "    input {\n",
      "      name: \"B2\"\n",
      "      type {\n",
      "        tensor_type {\n",
      "          elem_type: 1\n",
      "          shape {\n",
      "            dim {\n",
      "              dim_value: 1\n",
      "            }\n",
      "          }\n",
      "        }\n",
      "      }\n",
      "    }\n",
      "    output {\n",
      "      name: \"Y\"\n",
      "      type {\n",
      "        tensor_type {\n",
      "          elem_type: 1\n",
      "          shape {\n",
      "            dim {\n",
      "              dim_value: 1\n",
      "            }\n",
      "          }\n",
      "        }\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "  type: GRAPHS\n",
      "}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# An node that is also a graph\n",
    "graph_proto = helper.make_graph(\n",
    "    [\n",
    "        helper.make_node(\"FC\", [\"X\", \"W1\", \"B1\"], [\"H1\"]),\n",
    "        helper.make_node(\"Relu\", [\"H1\"], [\"R1\"]),\n",
    "        helper.make_node(\"FC\", [\"R1\", \"W2\", \"B2\"], [\"Y\"]),\n",
    "    ],\n",
    "    \"MLP\",\n",
    "    [\n",
    "        helper.make_tensor_value_info('X' , TensorProto.FLOAT, [1]),\n",
    "        helper.make_tensor_value_info('W1', TensorProto.FLOAT, [1]),\n",
    "        helper.make_tensor_value_info('B1', TensorProto.FLOAT, [1]),\n",
    "        helper.make_tensor_value_info('W2', TensorProto.FLOAT, [1]),\n",
    "        helper.make_tensor_value_info('B2', TensorProto.FLOAT, [1]),\n",
    "    ],\n",
    "    [\n",
    "        helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1]),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# output = ThisSpecificgraph([input, w1, b1, w2, b2])\n",
    "node_proto = helper.make_node(\n",
    "    \"graph\",\n",
    "    [\"Input\", \"W1\", \"B1\", \"W2\", \"B2\"],\n",
    "    [\"Output\"],\n",
    "    graph=[graph_proto],\n",
    ")\n",
    "\n",
    "print(\"\\nNodeProto that contains a graph:\\n\")\n",
    "print(node_proto)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
